import javafx.util.Pair;

import java.util.*;

public class Solution1373 {

    /**
     * Summary of a subtree that has been verified to be a binary search tree:
     * its smallest key, its largest key, and the sum of all its keys.
     */
    static class NodeInfo {
        final int min;
        final int max;
        final int sum;

        public NodeInfo(int min, int max, int sum) {
            this.min = min;
            this.max = max;
            this.sum = sum;
        }
    }

    /**
     * Maps every node whose subtree is a valid BST to that subtree's summary.
     * Nodes rooting a non-BST subtree are absent. Rebuilt on each call to
     * {@link #maxSumBST(TreeNode)}.
     * NOTE(review): keyed by TreeNode, so this assumes TreeNode uses identity
     * equals/hashCode (as LeetCode's TreeNode does) — confirm.
     */
    Map<TreeNode, NodeInfo> balanceTreeRootMap = new HashMap<>();

    /** Returns true if {@code node} is null (an empty BST) or was recorded as a BST root. */
    private boolean isBstRoot(TreeNode node) {
        return node == null || balanceTreeRootMap.containsKey(node);
    }

    /**
     * Checks the BST ordering at {@code root}: every key in the left subtree must be
     * strictly smaller than {@code root.val}, and every key in the right subtree strictly larger.
     *
     * @param leftInfo  summary of the left BST subtree, or null if the left child is absent
     * @param rightInfo summary of the right BST subtree, or null if the right child is absent
     */
    private boolean satisfiesBstOrdering(TreeNode root, NodeInfo leftInfo, NodeInfo rightInfo) {
        if (leftInfo != null && leftInfo.max >= root.val) {
            return false;
        }
        return rightInfo == null || rightInfo.min > root.val;
    }

    /**
     * Builds the summary of the BST rooted at {@code root} from its children's summaries.
     * Because the ordering has already been validated, the minimum comes from the left
     * subtree (or the root if there is none) and the maximum from the right subtree.
     */
    private NodeInfo combine(TreeNode root, NodeInfo leftInfo, NodeInfo rightInfo) {
        int min = leftInfo != null ? leftInfo.min : root.val;
        int max = rightInfo != null ? rightInfo.max : root.val;
        int sum = root.val
                + (leftInfo != null ? leftInfo.sum : 0)
                + (rightInfo != null ? rightInfo.sum : 0);
        return new NodeInfo(min, max, sum);
    }

    /**
     * Visits every node of the tree so that children are processed before their parent,
     * recording each BST root in {@link #balanceTreeRootMap}.
     * Uses an explicit two-stack post-order instead of recursion so that very deep
     * (e.g. list-shaped) trees cannot overflow the call stack.
     */
    private void collectBstRootInfo(TreeNode root) {
        if (root == null) {
            return;
        }
        Deque<TreeNode> pending = new ArrayDeque<>();
        Deque<TreeNode> childrenFirst = new ArrayDeque<>();
        pending.push(root);
        while (!pending.isEmpty()) {
            TreeNode node = pending.pop();
            // A node is pushed before any of its descendants, so it is popped after all of them.
            childrenFirst.push(node);
            if (node.left != null) {
                pending.push(node.left);
            }
            if (node.right != null) {
                pending.push(node.right);
            }
        }
        while (!childrenFirst.isEmpty()) {
            recordIfBstRoot(childrenFirst.pop());
        }
    }

    /** Records {@code node} as a BST root if both children are BSTs and the ordering holds. */
    private void recordIfBstRoot(TreeNode node) {
        if (!isBstRoot(node.left) || !isBstRoot(node.right)) {
            return;
        }
        NodeInfo leftInfo = balanceTreeRootMap.get(node.left);
        NodeInfo rightInfo = balanceTreeRootMap.get(node.right);
        if (!satisfiesBstOrdering(node, leftInfo, rightInfo)) {
            return;
        }
        balanceTreeRootMap.put(node, combine(node, leftInfo, rightInfo));
    }

    /**
     * Returns the maximum key sum over all subtrees of {@code root} that are binary search trees.
     * The empty subtree counts as a BST with sum 0, so the result is never negative.
     *
     * @param root tree root; may be null
     * @return the maximum BST subtree sum, or 0 if every BST subtree has a negative sum
     */
    public int maxSumBST(TreeNode root) {
        // Discard results from any previous call so that reusing this instance is safe.
        balanceTreeRootMap.clear();
        collectBstRootInfo(root);
        // mapToInt + max avoids the overflow-prone "a - b" comparator and boxing.
        int best = balanceTreeRootMap.values().stream().mapToInt(info -> info.sum).max().orElse(0);
        return Math.max(best, 0);
    }

    /**
     * Parses a comma-separated level-order serialization (e.g. {@code "1,2,null,3"}) into a tree.
     * The literal {@code null} (or a blank token) marks a missing child. Trailing missing
     * children may be omitted, as in LeetCode's format.
     *
     * @param data serialized tree; null or blank yields an empty tree
     * @return the root node, or null for an empty tree
     * @throws NumberFormatException if a non-null token is not an integer
     */
    static public TreeNode deserialize(String data) {
        if (data == null || data.trim().isEmpty()) {
            return null;
        }
        String[] tokens = data.split(",");
        TreeNode root = parseNode(tokens[0]);
        if (root == null) {
            return null;
        }
        Queue<TreeNode> queue = new ArrayDeque<>();
        queue.add(root);
        int ind = 1;
        // Stop as soon as tokens run out; the remaining children are implicitly null.
        while (!queue.isEmpty() && ind < tokens.length) {
            TreeNode parent = queue.poll();
            TreeNode left = parseNode(tokens[ind++]);
            if (left != null) {
                parent.left = left;
                queue.add(left);
            }
            if (ind < tokens.length) {
                TreeNode right = parseNode(tokens[ind++]);
                if (right != null) {
                    parent.right = right;
                    queue.add(right);
                }
            }
        }
        return root;
    }

    /** Converts one serialized token to a node; returns null for "null" or blank tokens. */
    private static TreeNode parseNode(String token) {
        String trimmed = token.trim();
        if (trimmed.isEmpty() || trimmed.equals("null")) {
            return null;
        }
        return new TreeNode(Integer.parseInt(trimmed));
    }

    public static void main(String[] args) {
        TreeNode root = deserialize("-4,-2,-5,null,null,null,null");
        Solution1373 solution1373 = new Solution1373();
        System.out.println(solution1373.maxSumBST(root));
    }
}
